import contextlib
import time
from collections import OrderedDict

import numpy as np
import onnx
import onnx_graphsurgeon
import torch
from PIL import Image


def to_binary_data(path, size=(640, 640), output_name="input_tensor.bin"):
    """Dump an image as a raw float32 NCHW tensor file.

    The file is intended to be fed to a TensorRT benchmark, e.g.
    ``trtexec --loadInputs='image:input_tensor.bin'``.

    Args:
        path: Path of the image to load.
        size: Target ``(width, height)`` passed to ``PIL.Image.resize``.
        output_name: Path of the raw binary file to write.

    The written data has shape ``(1, 3, height, width)``, dtype float32,
    and values scaled to ``[0, 1]``. It is laid out in C order using the
    machine's native byte order, with no header.
    """
    # The context manager closes the underlying file handle promptly.
    # convert("RGB") guarantees exactly 3 channels: RGBA/palette images
    # would otherwise give the wrong channel count, and grayscale images
    # would have no channel axis for the transpose below.
    with Image.open(path) as img:
        rgb = img.convert("RGB").resize(size)
    # HWC uint8 -> CHW float32, add a leading batch axis, scale to [0, 1].
    data = np.asarray(rgb, dtype=np.float32).transpose(2, 0, 1)[None] / 255.0
    # ndarray.tofile always writes in C order, so the transposed view is
    # serialized as a contiguous NCHW tensor.
    data.tofile(output_name)


def yolo_insert_nms(
    path,
    score_threshold=0.01,
    iou_threshold=0.7,
    max_output_boxes=300,
    simplify=False,
    output_path="yolo_w_nms.onnx",
    input_shapes=None,
):
    """Append a TensorRT ``EfficientNMS_TRT`` node to a YOLO ONNX model.

    The first two outputs of the loaded graph are wired into the NMS plugin
    as its ``boxes`` and ``scores`` inputs, in that order. The graph outputs
    are then replaced by the plugin's four detection outputs.

    References:
        http://www.xavierdupre.fr/app/onnxcustom/helpsphinx/api/onnxops/onnx__EfficientNMS_TRT.html
        https://huggingface.co/spaces/muttalib1326/Punjabi_Character_Detection/blob/3dd1e17054c64e5f6b2254278f96cfa2bf418cd4/utils/add_nms.py

    Args:
        path: Path of the source ONNX model.
        score_threshold: ``score_threshold`` attribute of the NMS plugin.
        iou_threshold: ``iou_threshold`` attribute of the NMS plugin.
        max_output_boxes: Maximum number of detections kept per image. This
            is also the static second dimension of the detection outputs.
        simplify: If True, run ``onnxsim.simplify`` on the model first. This
            requires the optional ``onnxsim`` package.
        output_path: Where the modified model is saved. The default
            ``"yolo_w_nms.onnx"`` matches the previous hard-coded behavior.
        input_shapes: ``overwrite_input_shapes`` passed to onnxsim; only used
            when ``simplify`` is True. Defaults to
            ``{"image": [1, 3, 640, 640]}``.

    Raises:
        IndexError: If the graph has fewer than two outputs to feed into the
            NMS node.

    Outputs of the saved model:
        num_dets (int32, [batch, 1]), det_boxes (float32, [batch, topk, 4]),
        det_scores (float32, [batch, topk]), det_classes (int32, [batch, topk]).
    """
    onnx_model = onnx.load(path)

    if simplify:
        # Aliased so the import does not shadow the boolean ``simplify`` argument.
        from onnxsim import simplify as onnxsim_simplify

        if input_shapes is None:
            input_shapes = {"image": [1, 3, 640, 640]}
        onnx_model, _ = onnxsim_simplify(onnx_model, overwrite_input_shapes=input_shapes)

    graph = onnx_graphsurgeon.import_onnx(onnx_model)
    graph.toposort()
    graph.fold_constants()
    graph.cleanup()

    # Assumes the exporter emits boxes first and per-class scores second.
    # TODO(review): confirm against the export script.
    boxes, scores = graph.outputs[0], graph.outputs[1]

    topk = max_output_boxes
    attrs = OrderedDict(
        plugin_version="1",
        # -1: no class index is treated as background.
        background_class=-1,
        max_output_boxes=topk,
        score_threshold=score_threshold,
        iou_threshold=iou_threshold,
        # Scores are used as-is; the plugin applies no sigmoid.
        score_activation=False,
        # 0: boxes are in corner format [x1, y1, x2, y2] (per plugin docs).
        box_coding=0,
    )

    outputs = [
        onnx_graphsurgeon.Variable("num_dets", np.int32, [-1, 1]),
        onnx_graphsurgeon.Variable("det_boxes", np.float32, [-1, topk, 4]),
        onnx_graphsurgeon.Variable("det_scores", np.float32, [-1, topk]),
        onnx_graphsurgeon.Variable("det_classes", np.int32, [-1, topk]),
    ]

    graph.layer(
        op="EfficientNMS_TRT",
        name="batched_nms",
        inputs=[boxes, scores],
        outputs=outputs,
        attrs=attrs,
    )

    # Replacing the outputs lets cleanup() drop any nodes that only fed the
    # previous graph outputs.
    graph.outputs = outputs
    graph.cleanup().toposort()

    onnx.save(onnx_graphsurgeon.export_onnx(graph), output_path)


class TimeProfiler(contextlib.ContextDecorator):
    """Accumulate elapsed time over ``with`` blocks or decorated calls.

    Each enter/exit pair adds its duration, in seconds, to ``total``. When
    CUDA is available, the device is synchronized before every clock read,
    so queued asynchronous GPU work is included in the measured interval.

    Example::

        profiler = TimeProfiler()
        with profiler:
            run_step()
        print(profiler.total)
    """

    def __init__(self):
        # Accumulated elapsed seconds over all completed timed regions.
        self.total = 0

    def __enter__(self):
        self.start = self.time()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Time is accumulated even if the body raised. Returning None
        # (falsy) lets any exception propagate.
        self.total += self.time() - self.start

    def reset(self):
        """Zero the accumulated total."""
        self.total = 0

    def time(self):
        """Return a clock reading in seconds, after syncing CUDA if available.

        Uses ``time.perf_counter``, which is monotonic and has the highest
        available resolution. It is therefore suited to measuring intervals,
        unlike ``time.time``, which can jump when the system clock is
        adjusted. Only differences between readings are meaningful.
        """
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return time.perf_counter()
